#####
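# Script to generate figure 2: optimal launch paths and the resulting
# collision-risk, satellite and debris trajectories under the baseline and
# reduced-cost parameter sets, saved to
# ../images/figure-5--example-optimal-trajectories.png.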
#####

library(ggplot2)
library(cowplot)
library(patchwork)
library(viridis)
library(tidyverse)
library(pracma)
library(nleqslv)
library(rootSolve)
library(fields)
library(doParallel)
library(scales)
library(data.table)

source("equations.r")
source("simulation_functions.r")
source("planner_simulation.r")

library(simFnsEqns)

## Model parameters ----
# Scalar model inputs. They are collected into the one-row data frame
# `params` below, which is what the simulation functions receive.
# NOTE(review): the globals `F` and `T` shadow R's FALSE/TRUE shortcuts for the
# rest of the session. They are kept because params$F / params$T and possibly
# the sourced helpers rely on these names.
avg_sat_decay <- 1 # 1: satellites are infinitely lived
p <- 1
F <- 35
discount <- 0.99
r <- (1 - discount) / discount # rate implied by the per-period discount factor
fe_eqm <- p / F - r            # the equilibrium risk
d <- 0.5
m <- 0
aSS <- 0.005
aSD <- 0.005
aDD <- 0.025
bSS <- 10
bSD <- 5.375
bDD <- 0.9 # 0.075
T <- 250
label <- "benchmark"

## Calibration scale factors ----
# Multipliers applied later to convert model-unit satellite / debris stocks
# into calibrated counts.
s_scale_factor <- 68289.84 / 3.728625
d_scale_factor <- 367372.3 / 4.214421

params <- data.frame(
  avg_sat_decay = avg_sat_decay, p = p, F = F, discount = discount, r = r,
  fe_eqm = fe_eqm, d = d, m = m,
  aSS = aSS, aSD = aSD, aDD = aDD,
  bSS = bSS, bSD = bSD, bDD = bDD,
  T = T
)

# Second parameter set: F is scaled by p / (F * (2 * fe_eqm + r)), so that
# algebraically p / F_new - r = 2 * fe_eqm, i.e. params2 has twice the
# baseline equilibrium risk. All other columns are unchanged.
cost_reduction_factor <- with(params, (p / F) * (1 / (2 * fe_eqm + r)))
params2 <- transform(params, F = F * cost_reduction_factor)
params2$fe_eqm <- with(params2, p / F - r)

# Show both equilibrium risks for a quick sanity check.
print(params$fe_eqm)
print(params2$fe_eqm)

## Optimal launch paths ----
# Maximize objective_fts_cpp over a per-period launch vector (fnscale = -1
# turns optim's minimizer into a maximizer). Each run starts from zero launches
# in every period, with launches box-constrained to [0, 10] via L-BFGS-B.
# S = 0 and D = 0 are passed straight through to objective_fts_cpp.
# "lo" uses the baseline `params`; "hi" uses `params2` (reduced F).
init_launchpath <- rep(0, length.out = params$T)

opt_lp_lo <- optim(
  par = init_launchpath, fn = objective_fts_cpp,
  S = 0, D = 0, T = params$T, params = params,
  control = list(fnscale = -1),
  method = "L-BFGS-B", lower = 0, upper = 10
)

opt_lp_hi <- optim(
  par = init_launchpath, fn = objective_fts_cpp,
  S = 0, D = 0, T = params2$T, params = params2,
  control = list(fnscale = -1),
  method = "L-BFGS-B", lower = 0, upper = 10
)

# optim() reports problems through $convergence rather than by erroring
# (0 = success; 1 = iteration limit reached; 51/52 = L-BFGS-B warning/error).
# A non-converged path would silently feed the figure, so warn about it.
local({
  fits <- list(opt_lp_lo = opt_lp_lo, opt_lp_hi = opt_lp_hi)
  for (fit_name in names(fits)) {
    fit <- fits[[fit_name]]
    if (fit$convergence != 0) {
      warning(
        fit_name, ": optim() did not converge (code ", fit$convergence,
        "): ", toString(fit$message),
        call. = FALSE
      )
    }
  }
})

## Simulated trajectories under each optimal launch path ----
# gen_fts() replays the optimized launch vector from S = 0, D = 0. A `time`
# column is prepended to its output. seq_len(T + 1) yields 1, 2, ..., T + 1,
# so time is 1-based here.
# NOTE(review): if time was meant to start at 0 (the axis segments in the
# panels start at x = 0), use 0:params$T instead -- confirm.
opt_lo <- cbind(
  time = seq_len(params$T + 1),
  gen_fts(X = opt_lp_lo$par, S = 0, D = 0, T = params$T, params = params)
)

opt_hi <- cbind(
  time = seq_len(params2$T + 1),
  gen_fts(X = opt_lp_hi$par, S = 0, D = 0, T = params2$T, params = params2)
)

## Combined plotting frame ----
# 1. Join the two runs on `time`; shared column names get _lo / _hi suffixes.
# 2. Compute collision risk with L() from the model-unit stocks.
# 3. Only afterwards rescale launches/satellites (s_scale_factor) and debris
#    (d_scale_factor) to calibrated units. The risk step must come first
#    because it reads the unscaled columns.
fig_dfrm <- as.data.frame(opt_lo) %>%
  left_join(as.data.frame(opt_hi), by = "time", suffix = c("_lo", "_hi")) %>%
  mutate(
    risk_lo = L(sats_lo, debs_lo, params),
    risk_hi = L(sats_hi, debs_hi, params2)
  ) %>%
  mutate(
    across(
      c(launches_lo, launches_hi, sats_lo, sats_hi),
      function(x) x * s_scale_factor
    ),
    across(c(debs_lo, debs_hi), function(x) x * d_scale_factor)
  )

head(fig_dfrm)

## Panel: launches ----
# Launches for the first 25 rows: dashed = "hi" run, dotted = "lo" run.
# The two thin axis segments use annotate() so each is drawn exactly once. A
# geom_segment() with constant arguments inherits the plot data and draws one
# identical copy per row (25 overplotted copies).
figure_1time_a <- ggplot(data = fig_dfrm[1:25, ], aes(x = time)) +
  geom_line(aes(y = launches_hi), linetype = "dashed", color = "grey39", size = 1) +
  geom_line(aes(y = launches_lo), linetype = "dotted", color = "grey39", size = 1) +
  annotate("segment", x = 0, y = 0, xend = 35, yend = 0, size = 0.15) +
  annotate("segment", x = 0, y = 0, xend = 0, yend = 5, size = 0.15) +
  ylab("Launches") + xlab("Time") +
  theme_bw() +
  theme(text = element_text(family = "Arial", size = 15))

## Panel: collision risk ----
# Risk for the first 25 rows: dashed = "hi" run, dotted = "lo" run.
# Horizontal #7FC97F reference lines mark each parameter set's equilibrium
# risk fe_eqm (dashed = params2, dotted = params).
# The constant yintercepts are passed as layer parameters rather than inside
# aes(), and the axis segments use annotate(). Otherwise each would inherit the
# plot data and be drawn once per row.
# NOTE(review): the vertical axis segment runs to y = 5, and annotation
# positions expand the y scale, so this risk panel's y range reaches 5 --
# confirm that is intended.
figure_1time_b <- ggplot(data = fig_dfrm[1:25, ], aes(x = time)) +
  geom_line(aes(y = risk_hi), linetype = "dashed", color = "grey39", size = 1.5) +
  geom_line(aes(y = risk_lo), linetype = "dotted", color = "grey39", size = 1.5) +
  geom_hline(
    yintercept = params2$fe_eqm,
    linetype = "dashed", color = "#7FC97F", size = 1
  ) +
  geom_hline(
    yintercept = params$fe_eqm,
    linetype = "dotted", color = "#7FC97F", size = 1
  ) +
  ylab("Collision risk") + xlab("Time") +
  annotate("segment", x = 0, y = 0, xend = 35, yend = 0, size = 0.15) +
  annotate("segment", x = 0, y = 0, xend = 0, yend = 5, size = 0.15) +
  theme_bw() +
  theme(text = element_text(family = "Arial", size = 15))

## Panel: satellites ----
# Satellite stock for the first 25 rows: dashed = "hi" run, dotted = "lo" run.
# The axis segments use annotate() so each is drawn once. A geom_segment() with
# constant arguments would be repeated for every data row.
figure_1time_c <- ggplot(data = fig_dfrm[1:25, ], aes(x = time)) +
  geom_line(aes(y = sats_hi), linetype = "dashed", color = "grey39", size = 1) +
  geom_line(aes(y = sats_lo), linetype = "dotted", color = "grey39", size = 1) +
  ylab("Satellites") + xlab("Time") +
  annotate("segment", x = 0, y = 0, xend = 35, yend = 0, size = 0.15) +
  annotate("segment", x = 0, y = 0, xend = 0, yend = 5, size = 0.15) +
  theme_bw() +
  theme(text = element_text(family = "Arial", size = 15))


## Panel: debris ----
# Debris stock for the first 25 rows: dashed = "hi" run, dotted = "lo" run.
# The axis segments use annotate() so each is drawn once. A geom_segment() with
# constant arguments would be repeated for every data row.
figure_1time_d <- ggplot(data = fig_dfrm[1:25, ], aes(x = time)) +
  geom_line(aes(y = debs_hi), linetype = "dashed", color = "grey39", size = 1) +
  geom_line(aes(y = debs_lo), linetype = "dotted", color = "grey39", size = 1) +
  ylab("Debris") + xlab("Time") +
  annotate("segment", x = 0, y = 0, xend = 35, yend = 0, size = 0.15) +
  annotate("segment", x = 0, y = 0, xend = 0, yend = 5, size = 0.15) +
  theme_bw() +
  theme(text = element_text(family = "Arial", size = 15))


# cowplot 2x2 grid of the four panels with lowercase labels a-d.
# Note: the ggsave() call that writes the image builds its own patchwork
# layout and does not use this object.
fig1_time_panel <- plot_grid(
  plotlist = list(figure_1time_a, figure_1time_b, figure_1time_c, figure_1time_d),
  labels = letters[1:4],
  nrow = 2,
  align = "h"
)

## Save the four-panel figure ----
# Left column: launches over satellites; right column: collision risk over
# debris. patchwork tags the panels with uppercase letters.
# plot_layout() is added with `+` so it configures the composition once. With
# `&` it would also be added to every sub-plot, because `&` is for components
# (such as theme()) that should apply to all panels.
# The panels map no colour/linetype aesthetics, so they have no legends; the
# guide/legend settings are kept but have no visible effect.
fig_layout <- ((figure_1time_a / figure_1time_c) | (figure_1time_b / figure_1time_d)) +
  plot_annotation(tag_levels = "A") +
  plot_layout(guides = "collect") &
  theme(legend.position = "right")

ggsave(
  "../images/figure-5--example-optimal-trajectories.png",
  fig_layout,
  width = 16 * 2 / 3,
  height = 9 * 2 / 3,
  units = "in",
  dpi = 300
)
